#coding=utf-8
import os
import time
import torch
import torch.nn as nn
import utils
from torch.autograd import Variable


def instance_bce_with_logits(logits, labels, model_type):
    """Compute the training loss for a batch of 2-D model outputs.

    Args:
        logits: Model outputs; must be 2-D (checked with an assertion).
        labels: Targets. In the ``'BP'`` branch they are passed to
            ``binary_cross_entropy_with_logits`` and ``labels.size(1)`` is
            used as a scale factor; otherwise they are passed to
            ``nll_loss`` as its target.
        model_type: ``'BP'`` selects binary cross-entropy with logits; any
            other value selects ``nll_loss``.

    Returns:
        The loss tensor produced by the selected criterion.
    """
    assert logits.dim() == 2

    if model_type != 'BP':
        return nn.functional.nll_loss(logits, labels)

    bce = nn.functional.binary_cross_entropy_with_logits(logits, labels)
    # Rescale the reduced loss by the size of labels along dim 1.
    bce *= labels.size(1)
    return bce


def compute_score_with_logits(logits, labels):
    """Score each row by the label value at its highest-scoring column.

    Args:
        logits: 2-D tensor of model outputs; the index of the maximum along
            dim 1 is taken as the prediction for each row.
        labels: 2-D tensor of per-column label values with one row per row
            of ``logits``.

    Returns:
        A tensor shaped like ``labels`` that is zero everywhere except at
        the predicted column of each row, where it holds the label value.
        It lives on the same device as ``labels``.
    """
    # Arg-max column per row. Indices carry no gradient, so detach them, and
    # make sure they sit on the same device as the mask we scatter into.
    predicted = torch.max(logits, 1)[1].detach().to(labels.device)

    # Allocate the one-hot mask on the labels' device rather than forcing
    # CUDA, so the function also works on CPU tensors.
    one_hots = torch.zeros(labels.size(), device=labels.device)

    # scatter_(dim, index, value) writes `value` into this tensor at the
    # positions given by `index` along `dim`; every index must lie in
    # [0, self.size(dim) - 1]. Here it sets column predicted[i] of row i to 1.
    one_hots.scatter_(1, predicted.view(-1, 1), 1)

    return one_hots * labels


def train(model, train_loader, eval_loader, num_epochs, output, model_type):
    """Run the training loop, evaluating and checkpointing after every epoch.

    Args:
        model: Module called as ``model(v, b, q, a)``; its parameters are
            optimised with Adamax using the optimiser's default settings.
        train_loader: Iterable yielding ``(v, b, q, a)`` batches. The length
            of its ``dataset`` normalises the epoch loss and score.
        eval_loader: Loader handed to ``evaluate`` after each epoch.
        num_epochs: Number of passes over ``train_loader``.
        output: Directory that receives ``log.txt`` and ``model.pth``.
        model_type: Forwarded to ``instance_bce_with_logits`` to select the
            loss.

    Side effects:
        Creates ``output`` via ``utils.create_dir``, writes per-epoch lines
        through ``utils.Logger``, and saves ``model.state_dict()`` to
        ``model.pth`` whenever the evaluation score strictly improves.
        Every batch is moved to CUDA.
    """
    utils.create_dir(output)
    optimizer = torch.optim.Adamax(model.parameters())
    log = utils.Logger(os.path.join(output, 'log.txt'))
    checkpoint_path = os.path.join(output, 'model.pth')
    best_score = 0

    for epoch in range(num_epochs):
        loss_sum = 0
        score_sum = 0
        epoch_start = time.time()

        for v, b, q, a in train_loader:
            v, b, q, a = (Variable(x).cuda() for x in (v, b, q, a))

            pred = model(v, b, q, a)
            loss = instance_bce_with_logits(pred, a, model_type)
            loss.backward()
            # Clip the total gradient norm to 0.25 before the update.
            nn.utils.clip_grad_norm_(model.parameters(), 0.25)
            optimizer.step()
            optimizer.zero_grad()

            # Weight the batch loss by the batch size so that dividing by
            # the dataset length below gives a per-example average.
            loss_sum += loss.item() * v.size(0)
            score_sum += compute_score_with_logits(pred, a.data).sum()

        num_train = len(train_loader.dataset)
        loss_sum /= num_train
        score_sum = 100 * score_sum / num_train

        # Switch to evaluation mode for the evaluation pass, then back.
        model.eval()
        eval_score, bound = evaluate(model, eval_loader)
        model.train()

        # The logged time spans both the training pass and the evaluation.
        log.write('epoch %d, time: %.2f' % (epoch, time.time()-epoch_start))
        log.write('\ttrain_loss: %.2f, score: %.2f' % (loss_sum, score_sum))
        log.write('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound))

        # Keep only the parameters of the best-scoring epoch so far.
        if eval_score > best_score:
            torch.save(model.state_dict(), checkpoint_path)
            best_score = eval_score


def evaluate(model, dataloader):
    """Score ``model`` over ``dataloader`` without tracking gradients.

    The caller is responsible for putting the model into evaluation mode;
    this function does not change the model's training flag.

    Args:
        model: Module called as ``model(v, b, q, None)``.
        dataloader: Iterable yielding ``(v, b, q, a)`` batches. The length of
            its ``dataset`` is used as the divisor for both returned values.

    Returns:
        A ``(score, upper_bound)`` tuple. ``score`` is the summed output of
        ``compute_score_with_logits`` divided by the dataset length.
        ``upper_bound`` is the sum of each row's maximum value in ``a``
        divided by the dataset length, i.e. the best score reachable by
        always picking the top-valued column.
    """
    score = 0
    upper_bound = 0

    # Nothing in evaluation needs gradients, so disable them for the loop.
    with torch.no_grad():
        for v, b, q, a in dataloader:
            v = v.cuda()
            b = b.cuda()
            q = q.cuda()

            pred = model(v, b, q, None)
            score += compute_score_with_logits(pred, a.cuda()).sum()
            # Computed on `a` as loaded (not moved to CUDA), as before.
            upper_bound += a.max(1)[0].sum()

    num_examples = len(dataloader.dataset)
    score = score / num_examples
    upper_bound = upper_bound / num_examples
    return score, upper_bound
